{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xR7FJc8T66ea"
      },
      "outputs": [],
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5dSZ1TzFgXjm"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "#tf.config.experimental.set_visible_devices([], \"GPU\")\n",
        "\n",
        "import importlib\n",
        "from simulation_research.diffusion import ode_datasets\n",
        "from simulation_research.diffusion import unet\n",
        "from simulation_research.diffusion import samplers\n",
        "from simulation_research.diffusion import diffusion\n",
        "from simulation_research.diffusion import config as cfg\n",
        "from simulation_research.diffusion import train\n",
        "from clu import checkpoint\n",
        "importlib.reload(ode_datasets)\n",
        "importlib.reload(unet)\n",
        "importlib.reload(samplers)\n",
        "importlib.reload(train)\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib import rc\n",
        "rc('animation', html='jshtml')\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "import jax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iHrnHvwuRoeG"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import pickle\n",
        "\n",
        "# Experiment identifiers; nothing below in this cell reads them.\n",
        "username = \"finzi\"\n",
        "exp_name = \"test_3\"  # previously \"all_datasets_and_ic\"\n",
        "xid = 0\n",
        "# Directory of the training run to analyse. It must contain config.pickle,\n",
        "# data_std.pickle and a checkpoints/ sub-directory.\n",
        "workdir = \"/home/finzi/xm/test_diffusion/datasets_all2/43500559/1\"\n",
        "\n",
        "# NOTE(security): pickle.load can execute arbitrary code stored in the file;\n",
        "# only point workdir at runs you trust.\n",
        "# tf.io.gfile.GFile is the public TF2 file API (tf.io.gfile.Open is not exported).\n",
        "with tf.io.gfile.GFile(os.path.join(workdir, 'config.pickle'), \"rb\") as f:\n",
        "  config = pickle.load(f)\n",
        "with tf.io.gfile.GFile(os.path.join(workdir, 'data_std.pickle'), \"rb\") as f:\n",
        "  data_std = pickle.load(f)\n",
        "\n",
        "# Checkpoints live in a sub-directory of the run directory.\n",
        "checkpoint_dir = os.path.join(workdir, \"checkpoints\")\n",
        "\n",
        "# Checkpoint manager; the actual restore happens in a later cell via ckpt.restore(...).\n",
        "ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, {}, max_to_keep=2)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pPI3e_CZsygP"
      },
      "outputs": [],
      "source": [
        "#config.dataset=\"LorenzDataset\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PIAv-4h-FxqW"
      },
      "outputs": [],
      "source": [
        "config"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cT6pRtJAS-PB"
      },
      "outputs": [],
      "source": [
        "# config = cfg.get_config()\n",
        "# config.ic_conditioning=False\n",
        "# config.dataset='FitzHughDataset'\n",
        "\n",
        "from jax import random\n",
        "\n",
        "key = random.PRNGKey(config.seed)\n",
        "# Construct the dataset\n",
        "timesteps = 60\n",
        "ds = getattr(ode_datasets, config.dataset)(N=config.ds + config.bs)\n",
        "Zs = ds.Zs[config.bs:, :timesteps]  # pylint: disable=invalid-name\n",
        "test_x = ds.Zs[:config.bs, :timesteps]\n",
        "T_long = ds.T_long[:timesteps]  # pylint: disable=invalid-name\n",
        "dataset = tf.data.Dataset.from_tensor_slices(Zs)\n",
        "dataiter = dataset.shuffle(len(dataset)).batch(config.bs).as_numpy_iterator\n",
        "assert Zs.shape[1] == timesteps\n",
        "\n",
        "# initialize the model\n",
        "x = test_x  # (bs, N, C)\n",
        "modelconfig = unet.unet_64_config(\n",
        "    x.shape[-1], base_channels=config.channels, attention=config.attention)\n",
        "model = unet.UNet(modelconfig)\n",
        "noise = getattr(diffusion, config.noisetype)\n",
        "difftype = getattr(diffusion, config.difftype)(noise)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y2SgoTAmTfMX"
      },
      "outputs": [],
      "source": [
        "def cond_fn(x):\n",
        "  \"\"\"Returns the first three entries along axis 1 of `x`, or None.\n",
        "\n",
        "  None is returned when `config.ic_conditioning` is falsy.\n",
        "  \"\"\"\n",
        "  if config.ic_conditioning:\n",
        "    return x[:, :3]\n",
        "  return None\n",
        "\n",
        "\n",
        "dataloader = dataiter\n",
        "# One data batch plus random times, passed to model.init below.\n",
        "x = next(dataloader())\n",
        "t = np.random.rand(x.shape[0])\n",
        "key, init_seed = random.split(random.PRNGKey(config.seed))\n",
        "params = model.init(init_seed, x=x, t=t, train=False, cond=cond_fn(x))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WtZgEjuvUfQs"
      },
      "outputs": [],
      "source": [
        "ema_params = ckpt.restore(params)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6XFchwwNvJ27"
      },
      "outputs": [],
      "source": [
        "def score(params, x, t, train, cond=None):\n",
        "  \"\"\"Evaluates the network on a rescaled input and rescales its output.\n",
        "\n",
        "  Scaling is equivalent to that in https://arxiv.org/abs/2206.00364.\n",
        "\n",
        "  Args:\n",
        "    params: parameters handed to `model.apply`.\n",
        "    x: input batch; multiplied by 1/sqrt(sigma^2 + (scale*data_std)^2).\n",
        "    t: times forwarded to `difftype.sigma`, `difftype.scale` and the model.\n",
        "    train: forwarded to the model's `train` argument.\n",
        "    cond: optional conditioning; divided by `data_std` when given.\n",
        "\n",
        "  Returns:\n",
        "    The model output divided by sqrt(sigma^2 + scale^2 * data_std^2).\n",
        "  \"\"\"\n",
        "  noise_std, signal_scale = diffusion.unsqueeze_like(\n",
        "      x, difftype.sigma(t), difftype.scale(t))\n",
        "  scaled_x = x * (1 / jnp.sqrt(noise_std**2 + (signal_scale * data_std)**2))\n",
        "  if cond is not None:\n",
        "    cond = cond / data_std\n",
        "  raw = model.apply(params, x=scaled_x, t=t, train=train, cond=cond)\n",
        "  output_norm = jnp.sqrt(noise_std**2 + signal_scale**2 * data_std**2)\n",
        "  return raw / output_norm\n",
        "\n",
        "@jax.jit\n",
        "def score_out(x, t, cond=None) -\u003e jnp.ndarray:\n",
        "  \"\"\"Calls `score` with the restored `ema_params` and train=False.\n",
        "\n",
        "  A `t` that has no `shape` attribute, or an empty shape, is expanded to one\n",
        "  value per row of `x` before the call.\n",
        "  \"\"\"\n",
        "  t_is_batched = hasattr(t, 'shape') and bool(t.shape)\n",
        "  if not t_is_batched:\n",
        "    t = jnp.ones(x.shape[0]) * t\n",
        "  return score(ema_params, x, t, train=False, cond=cond)\n",
        "\n",
        "score_fn = score_out\n",
        "from functools import partial\n",
        "eval_scorefn = partial(score_out,cond=cond_fn(test_x))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u-LP9ElYthCl"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "import jax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oQEMYm0CupO-"
      },
      "outputs": [],
      "source": [
        "diff =difftype\n",
        "#sde_samples = samplers.sde_sample(diff, eval_scorefn, key, test_x.shape,nsteps=1000)\n",
        "#ode_samples = samplers.discrete_ode_sample(diff, eval_scorefn, key, test_x.shape,nsteps=1000)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TR_-SXuxkc2T"
      },
      "outputs": [],
      "source": [
        "def lorenz_C(x):\n",
        "    fourier_mag = jnp.abs(jnp.fft.rfft(x[...,0],axis=-1))\n",
        "    return -(fourier_mag[...,1:].mean(-1)-.6)\n",
        "\n",
        "def fitz_C(x):\n",
        "    C = jnp.max(x[...,:2].mean(-1),-1)-2.5\n",
        "    return C\n",
        "\n",
        "def pendulum_C(x):\n",
        "    raise NotImplementedError\n",
        "\n",
        "constraints = {'FitzHughDataset':fitz_C,\n",
        "          'LorenzDataset': lorenz_C,\n",
        "          'NPendulumDataset':pendulum_C,\n",
        "          }\n",
        "event_constraint = constraints[config.dataset]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6WJ8IXIlqqaD"
      },
      "outputs": [],
      "source": [
        "plt.plot(jnp.abs(jnp.fft.rfft(test_x[0,:,2]))[1:])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "omANnmkLCYef"
      },
      "outputs": [],
      "source": [
        "event_scores = samplers.event_scores(diff,score_fn, event_constraint, reg=1e-3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AbZOUGqJDl33"
      },
      "outputs": [],
      "source": [
        "# Event-conditioned samples from the stochastic sampler.\n",
        "sde_event_samples = samplers.sde_sample(diff, event_scores, key, test_x.shape, nsteps=1000)\n",
        "# A later cell plots ode_event_samples, so the ODE sampler has to run too;\n",
        "# with this call commented out that cell raises NameError on a fresh kernel.\n",
        "ode_event_samples = samplers.discrete_ode_sample(diff, event_scores, key, test_x.shape, nsteps=1000)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "__KZi4hy5bZV"
      },
      "outputs": [],
      "source": [
        "# Unconditional samples from the stochastic sampler.\n",
        "sde_samples = samplers.sde_sample(diff, score_fn, key, test_x.shape, nsteps=1000)\n",
        "# The rollout-error cells further down read ode_samples; leaving this call\n",
        "# commented out makes them fail with NameError on a fresh kernel.\n",
        "ode_samples = samplers.discrete_ode_sample(diff, score_fn, key, test_x.shape, nsteps=1000)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XyMymRAx5nJH"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q5ZQucl_BO3X"
      },
      "outputs": [],
      "source": [
        "T = ds.T_long[:timesteps]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3jj3VZr7H7Gj"
      },
      "outputs": [],
      "source": [
        "event_distribution = event_constraint(Zs)\n",
        "events_train = Zs[event_constraint(Zs)\u003e0]\n",
        "events_test = test_x[event_constraint(test_x)\u003e0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8HtkSbBxBaDg"
      },
      "outputs": [],
      "source": [
        "plt.plot(T,sde_samples[event_constraint(sde_samples)\u003e0][:5,:,0].T)\n",
        "plt.xlabel(r'$\\tau$')\n",
        "plt.ylabel('x')\n",
        "plt.title('Example model events')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "exHajRcgsVQS"
      },
      "outputs": [],
      "source": [
        "plt.plot(T,test_x[10:15,:,0].T)\n",
        "plt.xlabel(r'$\\tau$')\n",
        "plt.ylabel('x')\n",
        "plt.title('Data samples')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "t8R6fAl_BbAD"
      },
      "outputs": [],
      "source": [
        "plt.plot(T,events_test[:5,:,0].T)\n",
        "plt.xlabel(r'$\\tau$')\n",
        "plt.ylabel('x')\n",
        "plt.title('Example events in dataset')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WAfhJo2lyJ9K"
      },
      "outputs": [],
      "source": [
        "T = ds.T[:timesteps]\n",
        "# fitz_C subtracts 2.5 from the peak of this same channel mean, so the\n",
        "# reference line is drawn at that level rather than at 2.\n",
        "event_cutoff = 2.5\n",
        "plt.plot(T, events_test[0, :, :2].mean(-1))\n",
        "plt.plot(T, event_cutoff * np.ones_like(T), color='k')\n",
        "plt.xlabel(r'Time ($\\tau$)')\n",
        "plt.ylabel(r'$x(\\tau)$')\n",
        "plt.legend(['example spike', 'our cutoff y'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yG97VN-M4Cdb"
      },
      "outputs": [],
      "source": [
        "plt.hist(np.array(event_constraint(sde_samples)),bins=80,color='red',density=True,alpha=.5)\n",
        "#plt.hist(np.array(ode_samples[:,:timesteps,:2].mean(-1).max(-1)),bins=100,color='g',density=True,alpha=.2)\n",
        "plt.hist(np.array(event_constraint(ds.Zs[:4000,:timesteps])),bins=80,color='y',density=True,alpha=.5)\n",
        "\n",
        "#plt.yscale('log')\n",
        "#plt.xlabel(r'$\\max_\\tau x(\\tau)$')\n",
        "plt.xlabel(r'$.6-||F[x]_{1:}||_1$')\n",
        "plt.ylabel('Density')\n",
        "plt.ylim(1e-2,2.5)\n",
        "plt.axvline(x=0,color='k')\n",
        "plt.legend(['y cutoff','Diffusion samples','True distribution'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XsbCRAEzf1pq"
      },
      "outputs": [],
      "source": [
        "sde_event_samples.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GW2MLjSQnejk"
      },
      "outputs": [],
      "source": [
        "plt.hist(np.array(event_constraint(sde_event_samples[:,:])),bins=100,color='red',density=True,alpha=.5)\n",
        "#plt.hist(np.array(ode_samples[:,:timesteps,:2].mean(-1).max(-1)),bins=100,color='g',density=True,alpha=.2)\n",
        "plt.hist(np.array(event_constraint(events_train[:,:timesteps])),bins=50,color='y',density=True,alpha=.5)\n",
        "\n",
        "#plt.yscale('log')\n",
        "#plt.xlabel(r'$\\max_\\tau x(\\tau)$')\n",
        "plt.axvline(0,color='k')\n",
        "plt.xlabel(r'$.6-||F[x]_{1:}||_1$')\n",
        "plt.ylabel('density')\n",
        "plt.ylim(1e-2,6)\n",
        "plt.legend(['y cutoff','Conditional diffusion samples x|E','True distribution x|E'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yoe7518sFAZm"
      },
      "outputs": [],
      "source": [
        "import seaborn\n",
        "seaborn.kdeplot(np.array(sde_event_samples[:,:timesteps,:2].mean(-1).max(-1)))\n",
        "seaborn.kdeplot(np.array(events_train[:,:timesteps,:2].mean(-1).max(-1)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Pbb6r6YrEku1"
      },
      "outputs": [],
      "source": [
        "(event_constraint(sde_event_samples)\u003e0).mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tJa1lheuI0MK"
      },
      "outputs": [],
      "source": [
        "sde_events2 = sde_event_samples[event_constraint(sde_event_samples)\u003e0,:,0]\n",
        "plt.plot(T,sde_events2[:5].T)\n",
        "plt.xlabel(r'$\\tau$')\n",
        "plt.ylabel('x')\n",
        "plt.title('x|E conditional model samples')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eVWIg16zuZna"
      },
      "outputs": [],
      "source": [
        "ds.animate(sde_event_samples[event_constraint(sde_event_samples)\u003e0][1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dUpqVDScusVI"
      },
      "outputs": [],
      "source": [
        "ds.animate(test_x[1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qXXu_fh2hkP7"
      },
      "outputs": [],
      "source": [
        "(event_constraint(Zs)\u003e0).mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MWLSF1K3fchZ"
      },
      "outputs": [],
      "source": [
        "jnp.exp(-10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YT2uA6ImKWtJ"
      },
      "outputs": [],
      "source": [
        "type((jnp.ones(3)*.2).sum())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e7WNgPSDKLqw"
      },
      "outputs": [],
      "source": [
        "logprob,logprob_std = samplers.marginal_logprob(diff, score_fn, event_constraint, test_x[0].shape,nsteps=1000)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K6HHuMYJLDLh"
      },
      "outputs": [],
      "source": [
        "# Likelihood of the same two event samples under the event-conditioned scores\n",
        "# and under the plain score function (score_fn; `scores_fn` was an undefined name).\n",
        "conditional_likelihood = samplers.discrete_time_likelihood(diff, event_scores, sde_event_samples[:2])\n",
        "unconditional_likelihood = samplers.discrete_time_likelihood(diff, score_fn, sde_event_samples[:2])\n",
        "print(conditional_likelihood, unconditional_likelihood)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Wmx5wGuGaFi0"
      },
      "outputs": [],
      "source": [
        "print(logprob,logprob_std)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0ew63yhpgvzJ"
      },
      "outputs": [],
      "source": [
        "jnp.exp(-logprob)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-z8LRKMbJrS9"
      },
      "outputs": [],
      "source": [
        "plt.plot(ode_event_samples[:5,:,:2].mean(-1).T)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_nknGmie_WmM"
      },
      "outputs": [],
      "source": [
        "#nll1 = samplers.compute_nll(diff,score_fn,key,sde_samples)\n",
        "nll2 = -samplers.discrete_time_likelihood(diff,score_fn,sde_samples[:5])/sde_samples[0].size"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "skdWWCH5BYVO"
      },
      "outputs": [],
      "source": [
        "nll2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0QBo0laDxPLo"
      },
      "outputs": [],
      "source": [
        "expanded = (mb[None]+jnp.zeros((10,1,1,1))).reshape(mb.shape[0]*10,*mb.shape[1:])#[:,slc]\n",
        "predictions = samplers.stochastic_sample(diff,inpainting_scores2(diff,score_fn,expanded[:,slc],slc,scale=300.),key,expanded.shape,N=2000,traj=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DpIsmHJ8z7vV"
      },
      "outputs": [],
      "source": [
        "sde_samples.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-OE4BPROthJG"
      },
      "outputs": [],
      "source": [
        "from jax import vmap\n",
        "#\n",
        "T = T_long\n",
        "z1 = sde_samples\n",
        "z2 = ode_samples\n",
        "z_gts = test_x[:z1.shape[0]]\n",
        "z0 = z_gts[:,0]#z_gts[:,0]\n",
        "#z0 = test_x\n",
        "#z_gts = vmap(ds.integrate,(0,None),0)(z0,T)\n",
        "z_pert = vmap(ds.integrate,(0,None),0)(z0+1e-3*np.random.randn(*z0.shape),T)\n",
        "z_random = vmap(ds.integrate,(0,None),0)(ds.sample_initial_conditions(z0.shape[0]),T)\n",
        "for pred in [z1,z2,z_pert,z_random]:\n",
        "  clamped_errs = jax.lax.clamp(1e-3,train.rel_err(pred,z_gts),np.inf)\n",
        "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
        "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
        "  plt.plot(T,rel_errs)\n",
        "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
        "\n",
        "plt.plot()\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Prediction Error')\n",
        "plt.legend(['SDE Diffusion Model Rollout','ode','1e-3 Perturbed GT','Random Init'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Mh8Yn7rethMQ"
      },
      "outputs": [],
      "source": [
        "from jax import vmap\n",
        "# Reference trajectories come from ds.integrate started at each sample's own\n",
        "# first state, so the curves compare a sample against that integration.\n",
        "T = T_long\n",
        "z = ode_samples\n",
        "z0 = z[:, 0]\n",
        "z_gts = vmap(ds.integrate, (0, None), 0)(z0, T)\n",
        "z_pert = vmap(ds.integrate, (0, None), 0)(z0 + 1e-3 * np.random.randn(*z0.shape), T)\n",
        "z_random = vmap(ds.integrate, (0, None), 0)(ds.sample_initial_conditions(z0.shape[0]), T)\n",
        "for pred in [z, z_pert, z_random]:\n",
        "  # Clamp from below so the log stays finite, then average in log space over axis 0.\n",
        "  clamped_errs = jax.lax.clamp(1e-3, train.rel_err(pred, z_gts), np.inf)\n",
        "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
        "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
        "  plt.plot(T, rel_errs)\n",
        "  plt.fill_between(T, rel_errs / rel_stds, rel_errs * rel_stds, alpha=.1)\n",
        "\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Prediction Error')\n",
        "# z holds the ODE samples, so the first curve is labelled as ODE, not SDE.\n",
        "plt.legend(['ODE Diffusion Model Rollout', '1e-3 Perturbed GT', 'Random Init'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KtzsH_kSthO_"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SQ2NdxBNT7Y8"
      },
      "outputs": [],
      "source": [
        "from flax.core.frozen_dict import FrozenDict\n",
        "import numpy as np\n",
        "def sum_params(params):\n",
        "  if isinstance(params, (jax.numpy.ndarray,np.ndarray)):\n",
        "    return params.sum()\n",
        "  elif isinstance(params, (dict, FrozenDict)):\n",
        "    return sum([sum_params(v) for v in params.values()])\n",
        "  else:\n",
        "    assert False, type(params)\n",
        "print(sum_params(params))\n",
        "print(sum_params(p2))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_ObQ3g6EW4Ht"
      },
      "outputs": [],
      "source": [
        "type(None)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-zBdcsbTUK80"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0Nxnr-3zUqHJ"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JsADCkg_VAbS"
      },
      "outputs": [],
      "source": [
        "type(p2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9bqrflL9bhu4"
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "jnp.exp(973.3657-977.17847)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nChZW4nUQY6Z"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "private_outputs": true,
      "provenance": [
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
